////////////////////////////////////////////////////////////
//
// SFML - Simple and Fast Multimedia Library
// Copyright (C) 2007-2009 Laurent Gomila (laurent.gom@gmail.com)
//
// This software is provided 'as-is', without any express or implied warranty.
// In no event will the authors be held liable for any damages arising from the use of this software.
//
// Permission is granted to anyone to use this software for any purpose,
// including commercial applications, and to alter it and redistribute it freely,
// subject to the following restrictions:
//
// 1. The origin of this software must not be misrepresented;
//    you must not claim that you wrote the original software.
//    If you use this software in a product, an acknowledgment
//    in the product documentation would be appreciated but is not required.
//
// 2. Altered source versions must be plainly marked as such,
//    and must not be misrepresented as being the original software.
//
// 3. This notice may not be removed or altered from any source distribution.
//
////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////
// Headers
////////////////////////////////////////////////////////////
#include <SFML/Network/TcpSocket.hpp>
#include <SFML/Network/IpAddress.hpp>
#include <SFML/Network/Packet.hpp>
#include <SFML/Network/SocketImpl.hpp>
#include <SFML/System/Err.hpp>
#include <algorithm>
#include <cstring>
#include <limits>

#ifdef _MSC_VER
    #pragma warning(disable : 4127) // "conditional expression is constant" generated by the FD_SET macro
#endif


namespace sf
{
////////////////////////////////////////////////////////////
/// Default constructor: initializes the Socket base with the Tcp socket type.
TcpSocket::TcpSocket()
    : Socket(Tcp)
{
}


////////////////////////////////////////////////////////////
unsigned short TcpSocket::GetLocalPort() const
{
    if (GetHandle() != priv::SocketImpl::InvalidSocket())
    {
        // Retrieve informations about the local end of the socket
        sockaddr_in address;
        priv::SocketImpl::AddrLength size = sizeof(address);
        if (getsockname(GetHandle(), reinterpret_cast<sockaddr*>(&address), &size) != -1)
        {
            return ntohs(address.sin_port);
        }
    }

    // We failed to retrieve the port
    return 0;
}


////////////////////////////////////////////////////////////
/// Get the address of the remote end of the socket.
///
/// \return Address reported by getpeername(), converted to host byte order,
///         or IpAddress::None if the socket handle is invalid or the query fails
IpAddress TcpSocket::GetRemoteAddress() const
{
    // Nothing to query without a valid socket handle
    if (GetHandle() == priv::SocketImpl::InvalidSocket())
        return IpAddress::None;

    // Ask the system for the address of the remote end of the socket
    sockaddr_in peerAddress;
    priv::SocketImpl::AddrLength length = sizeof(peerAddress);
    const int result = getpeername(GetHandle(), reinterpret_cast<sockaddr*>(&peerAddress), &length);

    // We failed to retrieve the address
    if (result == -1)
        return IpAddress::None;

    return IpAddress(ntohl(peerAddress.sin_addr.s_addr));
}


////////////////////////////////////////////////////////////
unsigned short TcpSocket::GetRemotePort() const
{
    if (GetHandle() != priv::SocketImpl::InvalidSocket())
    {
        // Retrieve informations about the remote end of the socket
        sockaddr_in address;
        priv::SocketImpl::AddrLength size = sizeof(address);
        if (getpeername(GetHandle(), reinterpret_cast<sockaddr*>(&address), &size) != -1)
        {
            return ntohs(address.sin_port);
        }
    }

    // We failed to retrieve the port
    return 0;
}


////////////////////////////////////////////////////////////
/// Connect the socket to a remote peer.
///
/// With a zero timeout, a single connect() call is made in the socket's
/// current blocking mode. With a non-zero timeout, a blocking socket is
/// temporarily switched to non-blocking mode so that the wait can be bounded
/// with select(); its blocking mode is restored before returning. A socket
/// that is already non-blocking returns right after the connection attempt,
/// without waiting for the timeout.
///
/// \param remoteAddress Address of the remote peer
/// \param remotePort    Port of the remote peer
/// \param timeout       Maximum time to wait, in milliseconds (0 means no timeout)
///
/// \return Done on success, otherwise the status reported by
///         priv::SocketImpl::GetErrorStatus()
Socket::Status TcpSocket::Connect(const IpAddress& remoteAddress, unsigned short remotePort, Uint32 timeout)
{
    // Create the internal socket if it doesn't exist
    Create();

    // Create the remote address
    sockaddr_in address = priv::SocketImpl::CreateAddress(remoteAddress.ToInteger(), remotePort);

    if (timeout == 0)
    {
        // ----- We're not using a timeout: just try to connect -----

        // Connect the socket
        if (connect(GetHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) == -1)
            return priv::SocketImpl::GetErrorStatus();

        // Connection succeeded
        return Done;
    }

    // ----- We're using a timeout: we'll need a few tricks to make it work -----

    // Save the previous blocking state
    const bool blocking = IsBlocking();

    // Switch to non-blocking to enable our connection timeout
    if (blocking)
        SetBlocking(false);

    // Try to connect to the remote address
    if (connect(GetHandle(), reinterpret_cast<sockaddr*>(&address), sizeof(address)) >= 0)
    {
        // We got instantly connected! (it may not happen a lot...)
        // Restore the caller's blocking mode first, otherwise a blocking
        // socket would silently be left in non-blocking mode
        if (blocking)
            SetBlocking(true);

        return Done;
    }

    // Get the error status (must be read before any other socket call)
    Status status = priv::SocketImpl::GetErrorStatus();

    // If we were in non-blocking mode, return immediately
    if (!blocking)
        return status;

    // Otherwise, wait until something happens to our socket (success, timeout or error)
    if (status == Socket::NotReady)
    {
        // Setup the selector
        fd_set selector;
        FD_ZERO(&selector);
        FD_SET(GetHandle(), &selector);

        // Setup the timeout: whole seconds, then the remaining milliseconds as microseconds
        timeval time;
        time.tv_sec  = timeout / 1000;
        time.tv_usec = (timeout % 1000) * 1000;

        // Wait for something to write on our socket (which means that the connection request has returned)
        if (select(static_cast<int>(GetHandle() + 1), NULL, &selector, NULL, &time) > 0)
        {
            // At this point the connection may have been either accepted or refused.
            // To know whether it's a success or a failure, we must check the address of the connected peer
            if (GetRemoteAddress() != sf::IpAddress::None)
            {
                // Connection accepted
                status = Done;
            }
            else
            {
                // Connection refused
                status = priv::SocketImpl::GetErrorStatus();
            }
        }
        else
        {
            // Failed to connect before timeout is over
            // NOTE(review): a select() timeout does not itself set an error code, so the
            // status returned here depends on the last pending socket error -- confirm
            status = priv::SocketImpl::GetErrorStatus();
        }
    }

    // Switch back to blocking mode (only reached when the socket was blocking on entry)
    SetBlocking(true);

    return status;
}


////////////////////////////////////////////////////////////
/// Disconnect the socket; this is implemented by closing it.
void TcpSocket::Disconnect()
{
    Close();
}


////////////////////////////////////////////////////////////
/// Send raw data to the remote peer.
///
/// Calls send() repeatedly until all \a size bytes have been handed to the
/// system, or until send() reports an error.
///
/// \param data Pointer to the bytes to send (must not be null)
/// \param size Number of bytes to send (must not be 0)
///
/// \return Done once every byte has been sent, Error if the parameters are
///         invalid, otherwise the status reported by
///         priv::SocketImpl::GetErrorStatus()
Socket::Status TcpSocket::Send(const char* data, std::size_t size)
{
    // Check the parameters
    if (!data || (size == 0))
    {
        Err() << "Cannot send data over the network (no data to send)" << std::endl;
        return Error;
    }

    // The length is passed to send() as an int, so a single call must never be
    // given more than INT_MAX bytes; larger buffers are sent in several chunks
    // instead of being truncated (or turned negative) by a narrowing cast
    const std::size_t maxChunk = static_cast<std::size_t>(std::numeric_limits<int>::max());

    // Loop until every byte has been sent
    std::size_t totalSent = 0;
    while (totalSent < size)
    {
        // Send a chunk of data
        const std::size_t chunk = std::min(size - totalSent, maxChunk);
        const int sent = static_cast<int>(send(GetHandle(), data + totalSent, static_cast<int>(chunk), 0));

        // Check for errors
        if (sent < 0)
            return priv::SocketImpl::GetErrorStatus();

        totalSent += static_cast<std::size_t>(sent);
    }

    return Done;
}


////////////////////////////////////////////////////////////
/// Receive raw data from the remote peer.
///
/// Performs a single recv() call; fewer than \a size bytes may be returned.
///
/// \param data     Destination buffer (must not be null)
/// \param size     Maximum number of bytes to write into \a data
/// \param received Set to the number of bytes actually received (0 on failure)
///
/// \return Done if at least one byte was received, Disconnected if recv()
///         returned 0, Error if \a data is null, otherwise the status
///         reported by priv::SocketImpl::GetErrorStatus()
Socket::Status TcpSocket::Receive(char* data, std::size_t size, std::size_t& received)
{
    // First clear the variables to fill
    received = 0;

    // Check the destination buffer
    if (!data)
    {
        Err() << "Cannot receive data from the network (the destination buffer is invalid)" << std::endl;
        return Error;
    }

    // The length is passed to recv() as an int: clamp it so that a buffer larger
    // than INT_MAX bytes cannot turn into a negative or truncated length
    const std::size_t maxChunk = static_cast<std::size_t>(std::numeric_limits<int>::max());
    const int sizeToReceive = static_cast<int>(std::min(size, maxChunk));

    // Receive a chunk of bytes
    const int sizeReceived = static_cast<int>(recv(GetHandle(), data, sizeToReceive, 0));

    // Check the number of bytes received
    if (sizeReceived > 0)
    {
        received = static_cast<std::size_t>(sizeReceived);
        return Done;
    }
    else if (sizeReceived == 0)
    {
        // recv() returning 0 is reported as a disconnection
        return Socket::Disconnected;
    }
    else
    {
        return priv::SocketImpl::GetErrorStatus();
    }
}


////////////////////////////////////////////////////////////
/// Send a packet to the remote peer.
///
/// TCP is a stream protocol and does not preserve message boundaries, so the
/// packet is framed: its size is sent first as a 32-bit value in network byte
/// order, followed by the packet data.
///
/// \param packet Packet to send; its data is obtained through Packet::OnSend()
///
/// \return Done on success, otherwise the status returned by the raw Send()
///         call that failed
Socket::Status TcpSocket::Send(Packet& packet)
{
    // Ask the packet for the bytes to transmit
    std::size_t payloadSize = 0;
    const char* payload = packet.OnSend(payloadSize);

    // Send the size header first, so that the receiver knows where the packet ends
    Uint32 header = htonl(static_cast<unsigned long>(payloadSize));
    const Status headerStatus = Send(reinterpret_cast<const char*>(&header), sizeof(header));

    // Give up if the header could not be sent
    if (headerStatus != Done)
        return headerStatus;

    // A zero header means there is no payload to send after it
    if (header == 0)
        return Done;

    // Send the packet data
    return Send(payload, payloadSize);
}


////////////////////////////////////////////////////////////
/// Receive a packet from the remote peer.
///
/// Reads the 32-bit size header (network byte order), then the packet data.
/// Progress is accumulated in myPendingPacket, so when a raw Receive() call
/// returns a status other than Done, this function returns that status and a
/// later call resumes from the bytes already stored.
///
/// \param packet Packet to fill; it is cleared first and only filled once the
///               whole packet has been received
///
/// \return Done when a complete packet has been received, otherwise the
///         status returned by the raw Receive() call that stopped the transfer
Socket::Status TcpSocket::Receive(Packet& packet)
{
    // First clear the variables to fill
    packet.Clear();

    std::size_t received = 0;

    // Step 1: complete the size header
    // (even a 4 bytes variable may be received in more than one call;
    //  the loop is skipped if the header was completed by a previous call)
    while (myPendingPacket.SizeReceived < sizeof(myPendingPacket.Size))
    {
        char* cursor = reinterpret_cast<char*>(&myPendingPacket.Size) + myPendingPacket.SizeReceived;
        const std::size_t missing = sizeof(myPendingPacket.Size) - myPendingPacket.SizeReceived;

        const Status status = Receive(cursor, missing, received);

        // Record the progress before checking the status, so nothing received is lost
        myPendingPacket.SizeReceived += received;

        if (status != Done)
            return status;
    }

    // The header is complete: convert it to host byte order
    const Uint32 packetSize = ntohl(myPendingPacket.Size);

    // Step 2: receive the packet data, one chunk at a time
    char chunk[1024];
    while (myPendingPacket.Data.size() < packetSize)
    {
        // Never ask for more than what is left, nor more than the chunk can hold
        const std::size_t remaining = packetSize - myPendingPacket.Data.size();
        const std::size_t wanted = std::min(remaining, sizeof(chunk));

        const Status status = Receive(chunk, wanted, received);
        if (status != Done)
            return status;

        // Append the received bytes to the pending data
        if (received > 0)
        {
            const std::size_t offset = myPendingPacket.Data.size();
            myPendingPacket.Data.resize(offset + received);
            std::memcpy(&myPendingPacket.Data[offset], chunk, received);
        }
    }

    // Step 3: hand the complete data over to the user packet
    if (!myPendingPacket.Data.empty())
        packet.OnReceive(&myPendingPacket.Data[0], myPendingPacket.Data.size());

    // Reset the pending state for the next packet
    myPendingPacket = PendingPacket();

    return Done;
}

} // namespace sf
